name: Performance Benchmarks with NASLib

# Manual trigger only (`workflow_dispatch`). Every input is optional and has a default, and
# all of them are forwarded to `benchmarks/run_naslib.py` by the "Run performance benchmark"
# step of the job below.
on:
  workflow_dispatch:
    inputs:
      # Whitespace-separated sampler class names.
      sampler-list:
        description: 'Sampler List: A list of samplers to check the performance. Should be a whitespace-separated list of Optuna samplers. Each sampler must exist under `optuna.samplers` or `optuna.integration`.'
        required: false
        default: 'RandomSampler TPESampler'
      # Whitespace-separated JSON dictionaries; presumably paired positionally with
      # `sampler-list` (the defaults have two entries each) -- confirm in run_naslib.py.
      # The default is a single-quoted YAML scalar, so the backslashes in front of the double
      # quotes are NOT escapes: they are part of the value and are passed on literally.
      # NOTE(review): presumably the consumer needs the quotes pre-escaped -- confirm before
      # "cleaning up" this string.
      sampler-kwargs-list:
        description: 'Sampler Arguments List: A list of sampler keyword arguments. Should be a whitespace-separated list of json format dictionaries.'
        required: false
        default: '{} {\"multivariate\":true,\"constant_liar\":true}'
      # Whitespace-separated pruner class names.
      pruner-list:
        description: 'Pruner List: A list of pruners to check the performance. Should be a whitespace-separated list of Optuna pruners. Each pruner must exist under `optuna.pruners`.'
        required: false
        default: 'NopPruner'
      # Whitespace-separated JSON dictionaries of pruner keyword arguments.
      pruner-kwargs-list:
        description: 'Pruner Arguments List: A list of pruner keyword arguments. Should be a whitespace-separated list of json format dictionaries.'
        required: false
        default: '{}'
      # Passed to the benchmark script as `--budget`.
      budget:
        description: 'Number of Trials if the pruning is not enabled. If the pruning is enabled, the total number of steps is equal to `budget * (steps per trial)`.'
        required: false
        default: '100'
      # Passed to the benchmark script as `--n-runs`.
      n-runs:
        description: 'Number of Studies'
        required: false
        default: '10'
      # Passed to the benchmark script as `--n-concurrency`.
      n-concurrency:
        description: 'Number of Concurrent Trials'
        required: false
        default: '1'


jobs:
  # Single job: its steps install the toolchain, fetch the NAS-Bench-201 dataset files,
  # run `benchmarks/run_naslib.py` and publish the generated report as an artifact.
  performance-benchmarks-naslib:
    runs-on: ubuntu-latest

    steps:
    # Check out the repository; later steps install it (`pip install .`) and run
    # `benchmarks/run_naslib.py` from the working tree.
    - name: Checkout
      uses: actions/checkout@v4

    - name: Setup Python3.8
      uses: actions/setup-python@v5
      with:
        # Quoted so the version is always a string and never a YAML float
        # (an unquoted 3.10, for example, would silently become 3.1).
        python-version: '3.8'

    - name: Install gnuplot
      run: |
        # Use apt-get rather than apt: the `apt` front end is intended for interactive use
        # and does not guarantee a stable command-line interface for scripts.
        sudo apt-get update
        sudo apt-get -y install gnuplot

    # Cache pip's download cache between runs.
    - name: Setup cache
      uses: actions/cache@v4
      env:
        # Cached under a common name so that the other performance benchmarks can reuse it.
        cache-name: performance-benchmarks
      with:
        path: ~/.cache/pip
        # The key changes whenever a pyproject.toml changes; bump the `-v1` suffix to
        # force a fresh cache.
        key: ${{ runner.os }}-3.8-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}-v1
        # Fallback: same prefix without the version suffix.
        restore-keys: |
          ${{ runner.os }}-3.8-${{ env.cache-name }}-${{ hashFiles('**/pyproject.toml') }}

    # Install Optuna from the checked-out sources plus the kurobako Python package.
    - name: Install Python libraries
      run: |
        python -m pip install --upgrade pip
        pip install --progress-bar off -U setuptools

        # Install minimal dependencies and confirm that `import optuna` is successful.
        pip install --progress-bar off .
        python -c 'import optuna'
        optuna --version

        pip install --progress-bar off kurobako

    - name: Cache kurobako CLI
      id: cache-kurobako
      uses: actions/cache@v4
      with:
        path: ./kurobako
        # Keep the version in this key in sync with the release URL in the install step.
        key: kurobako-0-2-10

    - name: Install kurobako CLI
      # Download the binary only when the cache step did not restore it.
      if: steps.cache-kurobako.outputs.cache-hit != 'true'
      run: |
        # --fail: exit non-zero on an HTTP error instead of saving the error page as `kurobako`.
        curl --fail -L https://github.com/optuna/kurobako/releases/download/0.2.10/kurobako-0.2.10.linux-amd64 -o kurobako
        chmod +x kurobako
        # Smoke test: the freshly downloaded binary must at least be able to print its help.
        ./kurobako -h

    # Fetch NASLib as a source archive and install it in editable mode from ./NASLib
    # (the dataset steps place their files under NASLib/naslib/data).
    - name: Install naslib
      run: |
        # NOTE(review): this downloads the current tip of the `Develop` branch, so two runs
        # may install different NASLib revisions -- consider pinning a commit or tag.
        # --fail: exit non-zero on an HTTP error instead of handing an error page to unzip.
        curl --fail -L https://github.com/automl/NASLib/archive/refs/heads/Develop.zip -o NASLib-Develop.zip
        unzip NASLib-Develop.zip
        mv NASLib-Develop NASLib
        cd NASLib
        pip install --upgrade pip setuptools wheel
        pip install -e .
        cd ..

    # Log the exact package set of this run for later debugging.
    - name: Output installed packages
      run: pip freeze --all

    # Same information, shown as a dependency tree.
    - name: Output dependency tree
      run: pip install pipdeptree && pipdeptree

    - name: Cache nasbench201-cifar10 dataset
      id: cache-nb201-cifar10-dataset
      uses: actions/cache@v4
      with:
        path: NASLib/naslib/data/nb201_cifar10_full_training.pickle
        # Static key: once saved, the same file is reused for as long as the cache entry exists.
        key: cache-nb201-cifar10

    - name: Download nasbench201-cifar10 dataset
      # Only needed when the cache step did not restore the pickle.
      if: steps.cache-nb201-cifar10-dataset.outputs.cache-hit != 'true'
      run: |
        cd NASLib/naslib/data

        # Download a Google Drive file: $1 = Drive file id, $2 = destination path.
        function wget_gdrive {
          local file_id=$1
          local dest_path=$2

          # First request: store the session cookies and extract the `confirm=` token from the response.
          wget --save-cookies cookies.txt "https://docs.google.com/uc?export=download&id=${file_id}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1/p' > confirm.txt
          # Second request: send the cookies back together with the token to fetch the file itself.
          wget --load-cookies cookies.txt -O "${dest_path}" "https://docs.google.com/uc?export=download&id=${file_id}&confirm=$(<confirm.txt)"
          rm -f cookies.txt confirm.txt
        }
        wget_gdrive 1sh8pEhdrgZ97-VFBVL94rI36gedExVgJ nb201_cifar10_full_training.pickle

    - name: Cache nasbench201-cifar100 dataset
      id: cache-nb201-cifar100-dataset
      uses: actions/cache@v4
      with:
        path: NASLib/naslib/data/nb201_cifar100_full_training.pickle
        # Static key: once saved, the same file is reused for as long as the cache entry exists.
        key: cache-nb201-cifar100

    - name: Download nasbench201-cifar100 dataset
      # Only needed when the cache step did not restore the pickle.
      if: steps.cache-nb201-cifar100-dataset.outputs.cache-hit != 'true'
      run: |
        cd NASLib/naslib/data

        # Download a Google Drive file: $1 = Drive file id, $2 = destination path.
        function wget_gdrive {
          local file_id=$1
          local dest_path=$2

          # First request: store the session cookies and extract the `confirm=` token from the response.
          wget --save-cookies cookies.txt "https://docs.google.com/uc?export=download&id=${file_id}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1/p' > confirm.txt
          # Second request: send the cookies back together with the token to fetch the file itself.
          wget --load-cookies cookies.txt -O "${dest_path}" "https://docs.google.com/uc?export=download&id=${file_id}&confirm=$(<confirm.txt)"
          rm -f cookies.txt confirm.txt
        }
        wget_gdrive 1hV6-mCUKInIK1iqZ0jfBkcKaFmftlBtp nb201_cifar100_full_training.pickle

    - name: Cache nasbench201-imagenet16 dataset
      id: cache-nb201-imagenet16-dataset
      uses: actions/cache@v4
      with:
        path: NASLib/naslib/data/nb201_ImageNet16_full_training.pickle
        # Static key: once saved, the same file is reused for as long as the cache entry exists.
        key: cache-nb201-imagenet16

    - name: Download nasbench201-imagenet16 dataset
      # Only needed when the cache step did not restore the pickle.
      if: steps.cache-nb201-imagenet16-dataset.outputs.cache-hit != 'true'
      run: |
        cd NASLib/naslib/data

        # Download a Google Drive file: $1 = Drive file id, $2 = destination path.
        function wget_gdrive {
          local file_id=$1
          local dest_path=$2

          # First request: store the session cookies and extract the `confirm=` token from the response.
          wget --save-cookies cookies.txt "https://docs.google.com/uc?export=download&id=${file_id}" -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1/p' > confirm.txt
          # Second request: send the cookies back together with the token to fetch the file itself.
          wget --load-cookies cookies.txt -O "${dest_path}" "https://docs.google.com/uc?export=download&id=${file_id}&confirm=$(<confirm.txt)"
          rm -f cookies.txt confirm.txt
        }
        wget_gdrive 1FVCn54aQwD6X6NazaIZ_yjhj47mOGdIH nb201_ImageNet16_full_training.pickle


    # Run the benchmark script with the workflow inputs; results are written to ./out.
    - name: Run performance benchmark
      env:
        # The inputs are handed to the script through environment variables instead of being
        # expanded into the shell source by `${{ }}`. That way a quote or other shell
        # metacharacter in an input can neither break the command line nor inject commands,
        # and the values (including literal backslashes) reach the script unchanged.
        BENCH_BUDGET: ${{ github.event.inputs.budget }}
        BENCH_N_RUNS: ${{ github.event.inputs.n-runs }}
        BENCH_N_CONCURRENCY: ${{ github.event.inputs.n-concurrency }}
        BENCH_SAMPLER_LIST: ${{ github.event.inputs.sampler-list }}
        BENCH_SAMPLER_KWARGS_LIST: ${{ github.event.inputs.sampler-kwargs-list }}
        BENCH_PRUNER_LIST: ${{ github.event.inputs.pruner-list }}
        BENCH_PRUNER_KWARGS_LIST: ${{ github.event.inputs.pruner-kwargs-list }}
      run: |
        python benchmarks/run_naslib.py \
          --path-to-kurobako "." \
          --name-prefix "" \
          --budget "$BENCH_BUDGET" \
          --n-runs "$BENCH_N_RUNS" \
          --n-jobs 6 \
          --n-concurrency "$BENCH_N_CONCURRENCY" \
          --sampler-list "$BENCH_SAMPLER_LIST" \
          --sampler-kwargs-list "$BENCH_SAMPLER_KWARGS_LIST" \
          --pruner-list "$BENCH_PRUNER_LIST" \
          --pruner-kwargs-list "$BENCH_PRUNER_KWARGS_LIST" \
          --seed 0 \
          --out-dir "out"

    # Publish the files produced by the benchmark run as the `benchmark-report` artifact.
    - name: Upload benchmark report
      uses: actions/upload-artifact@v4
      with:
        name: benchmark-report
        path: |
          out/results.json
          out/report.md
          out/**/*.png

    # Fetch the artifact that was just uploaded.
    - name: Download benchmark report
      uses: actions/download-artifact@v4
      with:
        name: benchmark-report
        # `path` is the destination directory for download-artifact, not a list of files.
        # The uploaded files share `out` as their common root, so extracting into `out`
        # reproduces the original layout.
        path: out
